import torch
import torch.nn as nn
import torch.optim as optim

from src.data.data_utils import choose_dataset
from src.low_rank_neural_networks.layer.tucker_conv_BUG_adaptive import Conv2d_tucker_BUG_adaptive
from src.low_rank_neural_networks.layer.mat_conv_BUG_adaptive import Conv2d_mat_BUG_adaptive
import warnings
import argparse
# Suppress every Python warning for the rest of the run.
warnings.filterwarnings("ignore")

# ---------------- Command-line interface ----------------
# Each entry: (flag, destination attribute on `options`, value type, default).
_CLI_OPTIONS = (
    ("--tau", "tau", float, 0.08),
    ("--r", "start_rank_percent", float, 0.3),
    ("--m", "model", str, 'resnet'),
    ("--fact", "factorization", str, 'tucker'),
    ("--d", "data", str, 'cifar10'),
)
parser = argparse.ArgumentParser()
for _flag, _dest, _kind, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, dest=_dest, type=_kind, default=_default)
del _flag, _dest, _kind, _default, _CLI_OPTIONS

options = parser.parse_args()

# Publish the run configuration as module attributes so other code can read it.
# NOTE(review): explicitly importing "<pkg>.__init__" registers a module object
# distinct from the package "<pkg>" itself, so these attributes are only seen by
# code that imports "src.low_rank_neural_networks.__init__" the same way —
# confirm the layer/model modules do.
import src.low_rank_neural_networks.__init__ as g
g.factorization = options.factorization
g.glob_start_rank_perc = float(options.start_rank_percent)
g.glob_tau = float(options.tau)

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ---------------- Training hyperparameters ----------------
# Shared by every dataset.
batch_size = 128
momentum = 0.9
fine_tune_steps = 0  # number of fine-tuning epochs run after the main training

# Per-dataset values: tiny_imagenet uses its own setting, every other dataset
# uses the defaults.
if options.data.lower() == 'tiny_imagenet':
    learning_rate, num_epochs, weight_decay = 0.01, 50, 5e-3
else:
    learning_rate, num_epochs, weight_decay = 0.05, 100, 5e-4

# Import the low-rank (BUG adaptive) model definitions for the requested
# architecture family. Matching is by substring, so e.g. 'resnet50' resolves to
# the ResNet module.
if 'resnet' in options.model.lower():
    import src.low_rank_neural_networks.BUG.BUG_adaptive_ResNet as lr_BUG
elif 'vgg' in options.model.lower():
    import src.low_rank_neural_networks.BUG.BUG_adaptive_VGG as lr_BUG
elif 'alexnet' in options.model.lower():
    import src.low_rank_neural_networks.BUG.BUG_adaptive_alexnet as lr_BUG
else:
    # Fail fast: otherwise `lr_BUG` stays undefined and the script dies later
    # with an unhelpful NameError.
    raise ValueError(
        f"Unsupported model family '{options.model}': expected a name "
        "containing 'resnet', 'vgg' or 'alexnet'.")

print("Starting training with: ")
print("tau = " + str(options.tau))
print("start rank percent: " + str(options.start_rank_percent))

# -------- Dataset Selection -----------
# (num_classes, datapath) for the datasets with a dedicated setup; any other
# dataset name falls back to 10 classes under ./data/.
num_classes, datapath = {
    'imagenet': (1000, "./imageNet/"),
    'tiny_imagenet': (200, "./tiny_imagenet/"),
}.get(options.data.lower(), (10, "./data/"))

train_loader, val_loader, test_loader = choose_dataset(
    dataset_name=options.data.lower(),
    batch_size=batch_size,
    datapath=datapath,
)

# Initialize the model.
# The exact-name match below is stricter than the substring match used to pick
# the lr_BUG module (e.g. 'vgg16' selects the VGG module but none of these
# constructors), so unknown names are rejected explicitly.
if options.model.lower() == 'vgg':
    model = lr_BUG.vgg16().to(device)
    print("Train VGG16")
elif options.model.lower() == 'resnet':
    model = lr_BUG.resnet18(num_classes = num_classes).to(device)
    print("Train ResNet18")
elif options.model.lower() == 'resnet50':
    # NOTE(review): only resnet18 receives num_classes; vgg16/resnet50/alexnet use
    # their constructors' defaults, which may not match num_classes for
    # imagenet / tiny_imagenet — confirm.
    model = lr_BUG.resnet50().to(device)
    print("Train ResNet50")
elif options.model.lower() == 'alexnet':
    model = lr_BUG.alexnet().to(device)
    print("Train AlexNet")
else:
    # Fail fast instead of a later NameError on the undefined `model`.
    raise ValueError(
        f"Unsupported model '{options.model}': choose one of "
        "'vgg', 'resnet', 'resnet50', 'alexnet'.")

print(f' factorization : {options.factorization},dataset {options.data.lower()},model {options.model}')

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()

# Gather the factors of every low-rank layer: the U bases and the cores C.
lr_weights_Us = []
lr_weights_Cs = []
for lr_layer in lr_BUG.low_rank_layers:
    lr_weights_Us.extend(lr_layer.Us)
    lr_weights_Cs.append(lr_layer.C)

# Every remaining parameter is trained by optimizer_other. The low-rank factors
# are excluded by identity. A per-module isinstance filter over
# model.named_modules() does not work, because the root model is yielded first
# and its parameters() recurse into the low-rank layers. C would then be stepped
# a second time by optimizer_other.
# model.parameters() already yields each parameter once, in a deterministic order.
_low_rank_param_ids = {id(p) for p in lr_weights_Us + lr_weights_Cs}
other_params = [p for p in model.parameters() if id(p) not in _low_rank_param_ids]

optimizer_other = optim.SGD(other_params, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
# optimizer_Us.step() is not called in this script. The training loop updates U
# through lr_layer.step(dlrt_step="K", lr=...) and passes this optimizer's
# current (scheduled) learning rate.
optimizer_Us = optim.SGD(lr_weights_Us, lr=learning_rate, momentum=momentum)
optimizer_Cs = optim.SGD(lr_weights_Cs, lr=learning_rate, momentum=momentum, weight_decay=weight_decay)


def _one_cycle(optimizer, epochs):
    """Build the Tiny-ImageNet one-cycle LR scheduler for *optimizer*.

    The schedulers are stepped once per epoch, so the cycle is sized in epochs
    (``steps_per_epoch=1``): 10 warm-up epochs from max_lr/div_factor up to
    max_lr=0.02, then annealing for the remaining ``epochs - 10`` epochs.
    """
    return torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=0.02, epochs=epochs, steps_per_epoch=1,
        div_factor=10, final_div_factor=10, pct_start=10 / epochs)


if options.data.lower() == 'tiny_imagenet':
    # Each scheduler drives its own optimizer. Sizing the cycle per batch while
    # stepping per epoch would stretch it len(train_loader)-fold, keeping the LR
    # pinned near its starting value.
    scheduler_other = _one_cycle(optimizer_other, num_epochs)
    scheduler_Us = _one_cycle(optimizer_Us, num_epochs)
    # scheduler_Cs is also stepped once per fine-tuning epoch, so its cycle spans
    # both phases. Stepping a OneCycleLR past its total step count raises ValueError.
    scheduler_Cs = _one_cycle(optimizer_Cs, num_epochs + fine_tune_steps)
else:
    # Step decay: LR x0.1 after epochs 25 and 40.
    scheduler_other = torch.optim.lr_scheduler.MultiStepLR(optimizer_other, milestones=[25, 40], gamma=0.1)
    scheduler_Us = torch.optim.lr_scheduler.MultiStepLR(optimizer_Us, milestones=[25, 40], gamma=0.1)
    scheduler_Cs = torch.optim.lr_scheduler.MultiStepLR(optimizer_Cs, milestones=[25, 40], gamma=0.1)

# Initial learning rate of the core optimizer (not read elsewhere in this script).
t = [group['lr'] for group in optimizer_Cs.param_groups][0]
# Training loop
# Each batch runs two forward/backward passes over the same data:
#   1. U ("K") step: only U gradients are taped; each low-rank layer then updates
#      its U itself via lr_layer.step(dlrt_step="K", ...).
#   2. Core step: only C is taped; optimizer_Cs updates C, each layer's
#      step(dlrt_step="C", ...) is called (annotated below as "only truncates"),
#      and optimizer_other updates the remaining parameters.
# The statement order below is significant; do not reorder.
for epoch in range(num_epochs):
    model.train()
    # Accumulated but not read; the "#(loss_total)" remnants after the scheduler
    # calls suggest it once fed a loss-based scheduler.
    loss_total = 0.0
    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        ################### U step ###################
        for lr_layer in lr_BUG.low_rank_layers:
            # NOTE(review): presumably snapshots the current U before it is updated — confirm in the layer class.
            lr_layer.copy_U()
            lr_layer.activate_grad_step("K")  # activate only U grads

        # Forward pass (K) steps
        outputs = model(data)  # k forward
        loss = criterion(outputs, targets)  # loss eval

        # Backward
        for p in lr_weights_Us:
            p.grad = None  # reset gradients of U
        loss.backward()

        # update U
        # optimizer_Us is never stepped here; it only supplies the (scheduled) U learning rate.
        learning_rate = [ group['lr'] for group in optimizer_Us.param_groups ][0]
        for lr_layer in lr_BUG.low_rank_layers:
            lr_layer.step(dlrt_step="K", lr=learning_rate)  # U updates for each layer
            lr_layer.set_grad_zero()  # sets all low-rank gradients to zero

        ################### Core step ###################
        for lr_layer in lr_BUG.low_rank_layers:
            lr_layer.activate_grad_step("C")  # activate core taping, deactivate U taping

        outputs = model(data)  # core forward
        loss = criterion(outputs, targets)  # loss eval

        # Backward
        # Clear every model gradient (including any left over from the U-step
        # backward) so the optimizers below step on core-pass gradients only.
        for p in model.parameters():
            p.grad = None
        loss.backward()
        loss_total+=float(loss.item())
        # update C
        optimizer_Cs.step()  #
        learning_rate = [ group['lr'] for group in optimizer_Cs.param_groups ][0]
        for lr_layer in lr_BUG.low_rank_layers:
            lr_layer.step(dlrt_step="C", lr=learning_rate)  # only truncates

        # all other layers update
        optimizer_other.step()

        # scheduler.step(loss.detach())  # adapt learning rate
        # Every 100 batches: log the current rank of each low-rank layer and the core-step loss.
        rs = []
        if (batch_idx + 1) % 100 == 0:
            print("current ranks:")
            for lr_layer in lr_BUG.low_rank_layers:
                rs.append(lr_layer.rank)
            print(rs)
            print(
                f"Epoch [{epoch + 1}/{num_epochs}], Step [{batch_idx + 1}/{len(train_loader)}], Loss: {loss.item():.4f}")
    # Learning-rate schedules advance once per epoch.
    scheduler_other.step()#(loss_total)
    scheduler_Us.step()#(loss_total)
    scheduler_Cs.step()#(loss_total)
    learning_rate = [ group['lr'] for group in optimizer_Cs.param_groups ][0]
    print("Current learning rate: " + str(learning_rate))
    # Test the model
    # Evaluates on val_loader (the printed message says "test images").
    # NOTE(review): raises ZeroDivisionError if val_loader yields no samples.
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data, targets in val_loader:
            data = data.to(device)
            targets = targets.to(device)

            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)  # index of the highest logit = predicted class
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        accuracy = 100 * correct / total
        print(f"Accuracy of the network on the test images: {accuracy}%")

print("Training finished.")
print('='*100+'\n')
print("fine tuning starting...")

# Fine-tuning phase: only the cores are taped from here on. U is no longer
# updated, and lr_layer.step(dlrt_step="C", ...) (the truncation call in the
# main loop) is not invoked, so the layers are trained at their current ranks.
for lr_layer in lr_BUG.low_rank_layers:
    lr_layer.activate_grad_step("C")  # activate core taping, deactivate U taping

for epoch in range(fine_tune_steps):
    model.train()
    loss_total = 0.0
    for batch_idx, (data, targets) in enumerate(train_loader):
        data = data.to(device)
        targets = targets.to(device)

        outputs = model(data)  # core forward
        loss = criterion(outputs, targets)  # loss eval

        # Backward
        for p in model.parameters():
            p.grad = None
        loss.backward()
        loss_total += float(loss.item())

        # update C, then all other layers
        optimizer_Cs.step()
        optimizer_other.step()

        if (batch_idx + 1) % 100 == 0:
            print("current ranks:")
            for lr_layer in lr_BUG.low_rank_layers:
                print(lr_layer.rank)
            # Progress is reported against the fine-tuning epoch count.
            print(
                f"Epoch [{epoch + 1}/{fine_tune_steps}], Step [{batch_idx + 1}/{len(train_loader)}], Loss: {loss.item():.4f}")
    # Only the core scheduler advances during fine-tuning (once per epoch).
    scheduler_Cs.step()
    learning_rate = [group['lr'] for group in optimizer_Cs.param_groups][0]
    print("Current learning rate: " + str(learning_rate))

    # Evaluate on val_loader (the printed message says "test images").
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for data, targets in val_loader:
            data = data.to(device)
            targets = targets.to(device)

            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)  # index of the highest logit = predicted class
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

        accuracy = 100 * correct / total
        print(f"Accuracy of the network on the test images: {accuracy}%")